Fine Tuning a ResNet Model for Image Classification¶
Our goal is to train a classification model capable of identifying different kinds of mushrooms using a pretrained ResNet model. This transfer learning approach is effective with limited data, leveraging the knowledge embedded in ResNet50, which was trained on much larger datasets.
The model is trained with around 5000 images, applying conservative augmentation techniques.
Initially, the model is trained with its early convolutional layers frozen. These layers capture universal visual features like edges, textures, and basic shapes, with filters trained to detect low-level characteristics useful across a wide range of visual tasks. Freezing them prevents their weights from updating, restricting learning to class-specific features without altering the general image representations already learned by ResNet50, which minimizes the risk of overfitting. For reference, the structure of ResNet50 is:
- An input layer processing images,
- A 7x7 convolutional layer with 64 filters and a 3x3 max pooling layer, which together extract basic visual features like edges and textures,
- A sequence of residual blocks:
  - Conv Block 1: three layers with 64, 64, 256 filters, repeated three times.
  - Conv Block 2: increases filters to 128, 128, 512, repeated four times.
  - Conv Block 3: escalates to 256, 256, 1024, with six repetitions.
  - Conv Block 4: peaks at 512, 512, 2048 filters, repeated three times.
- An average pooling layer: reduces feature dimensionality.
- A fully connected layer: performs the final classification.
Conv Block 4 and the Fully Connected Layer are unfrozen during training to allow fine-tuning to dataset-specific features.
- Subsequently, the entire model is unfrozen for more comprehensive fine-tuning, including the pre-trained weights, over N epochs (using an early stopping callback to stop training after the validation loss stops decreasing).
Backpropagation updates weights based on loss gradients, supplemented by learning rate adjustments such as annealing to refine updates without significant deviations from the pretrained configuration.
These layers adapt their filters to better represent the unique features of our mushroom dataset, which in theory involves more complex feature interactions than those in the original training dataset (e.g., ImageNet).
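The annealing idea above can be sketched as a simple schedule function. This is a minimal illustration of the flat-then-cosine shape behind fastai's `flat_cos` scheduler (used later during tuning); the library's implementation differs in detail, and the `base_lr`/`pct_start` values here merely echo the selected parameters:

```python
import math

def flat_cos(step, total_steps, base_lr, pct_start=0.4):
    """Hold base_lr flat for the first pct_start of training, then
    anneal it toward zero along a half cosine."""
    flat_steps = int(total_steps * pct_start)
    if step < flat_steps:
        return base_lr
    progress = (step - flat_steps) / max(1, total_steps - flat_steps)
    return base_lr * 0.5 * (1 + math.cos(math.pi * progress))

# Learning rate at each of 100 steps: flat for ~40, then decaying
lrs = [flat_cos(s, 100, 1.4e-3) for s in range(100)]
```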
Model Selection¶
ResNet18 - 18 layers deep with fewer filters and layers compared to ResNet34 and ResNet50. It has 2 layers in each of its four sets of convolutional blocks. ~11.7 million parameters.
Achieved up to F1 = ~0.91 on the validation set
--
ResNet34 - 34 layers total; it has 3, 4, 6, and 3 layers in the four sets of its convolutional blocks. ~21.8 million parameters.
Achieved up to F1 = ~0.92
--
ResNet50 - introduces bottleneck layers to reduce the computational burden. It has a different block structure with 1x1, 3x3, and 1x1 convolutions where the 1x1 layers are responsible for reducing and then increasing (restoring) dimensions, keeping the 3x3 layer a bottleneck with fewer input/output dimensions. ~25.6 million params.
Achieved up to F1 = ~0.95
Parameter Selection and Tuning¶
We've performed extensive tuning for the model (~50 trials for narrowing down appropriate hyperparameter ranges and ~150 trials for selecting the optimal parameters) using Bayesian optimization, so that will be used as the basis of our model. However, most of the parameters don't seem to have a significant impact besides:
- `pct_start`: defines the percentage of the cycle spent increasing the learning rate, impacting the speed and effectiveness of training adjustments.
- `augmentations` (i.e., image transformations like changing the scale, size, or rotation of the image, erasing parts of the image, and various other techniques). One issue with our results is that we tuned the model using discrete sets of transformations instead of tuning individual parameters.
Additionally:
- We've found that using class weights to handle the imbalance in the dataset had no or limited effect, so we're not employing any over- or under-sampling technique (generating additional synthetic samples might be an option worth exploring).
- Tuning was only performed for the `ResNet50` model.
Selected Parameters:¶
{'batch_size': 64,
'base_lr': 0.0014050114695105182,
'weight_decay': 0.09308356639366534,
'lr_mult': 10,
'lr_scheduler': 'flat_cos',
'freeze_epochs': 6,
'pct_start': 0.39954565320066476,
'aug_mode': 'mult_1.25_more_trans'}
Model fine tuning and training¶
We've selected ResNet50 as our final "production" model because we were able to achieve significantly better performance with it; however, depending on the application and technical constraints, this might not be the optimal choice:
- Tuning/training in a reasonable amount of time requires a relatively recent GPU with at least 16 GB or so of memory.
- However, relative to more modern deep learning models (especially LLMs), memory requirements for inference are low: they shouldn't exceed a few hundred MB even for ResNet50 for a single image, with latency around 50-100 ms or so even on CPUs.
- This becomes a much more important issue if we're working with live recognition/video/AR rather than individual static images. In that case, the 3-5x performance difference might become very significant when running on hardware below high-end desktop/server level; e.g., AR and similar apps on mobile devices would generally use ResNet18, ResNet34, or more likely shallower models like MobileNet or EfficientNet, which have lower parameter counts and are faster.
Augmentation Mode: mult_1.25_more_trans | Data Type: Train | Model: resnet50
Model is on CUDA
| epoch | train_loss | valid_loss | accuracy | weighted F1 | micro F1 | macro F1 | time |
|---|---|---|---|---|---|---|---|
| 0 | 2.201875 | 1.372627 | 0.660160 | 0.649471 | 0.660160 | 0.599223 | 00:14 |
| 1 | 1.802115 | 1.206056 | 0.704525 | 0.698271 | 0.704525 | 0.659144 | 00:13 |
| 2 | 1.549524 | 1.134944 | 0.745342 | 0.740111 | 0.745342 | 0.710000 | 00:13 |
| 3 | 1.405611 | 1.135062 | 0.732032 | 0.728327 | 0.732032 | 0.690126 | 00:13 |
| 4 | 1.289279 | 1.090290 | 0.753327 | 0.749645 | 0.753327 | 0.726165 | 00:13 |
| 5 | 1.210203 | 1.057118 | 0.755102 | 0.750526 | 0.755102 | 0.719891 | 00:13 |
| epoch | train_loss | valid_loss | accuracy | weighted F1 | micro F1 | macro F1 | time |
|---|---|---|---|---|---|---|---|
| 0 | 1.095013 | 0.995483 | 0.819876 | 0.818026 | 0.819876 | 0.787816 | 00:16 |
| 1 | 0.925314 | 0.862077 | 0.858917 | 0.858179 | 0.858917 | 0.832653 | 00:16 |
| 2 | 0.829193 | 0.810338 | 0.877551 | 0.876764 | 0.877551 | 0.848774 | 00:16 |
| 3 | 0.755437 | 0.785751 | 0.881988 | 0.880593 | 0.881988 | 0.861716 | 00:16 |
| 4 | 0.703322 | 0.750437 | 0.900621 | 0.899623 | 0.900621 | 0.882583 | 00:15 |
| 5 | 0.676031 | 0.755283 | 0.891748 | 0.891298 | 0.891748 | 0.875207 | 00:16 |
| 6 | 0.643748 | 0.727963 | 0.910382 | 0.909308 | 0.910382 | 0.892095 | 00:16 |
| 7 | 0.621263 | 0.700396 | 0.921029 | 0.920402 | 0.921029 | 0.906657 | 00:16 |
| 8 | 0.610846 | 0.709269 | 0.919255 | 0.918628 | 0.919255 | 0.904956 | 00:16 |
| 9 | 0.597556 | 0.734286 | 0.888199 | 0.886879 | 0.888199 | 0.871929 | 00:16 |
| 10 | 0.592544 | 0.698643 | 0.916593 | 0.915706 | 0.916593 | 0.900124 | 00:16 |
| 11 | 0.591604 | 0.691277 | 0.918367 | 0.917716 | 0.918367 | 0.899072 | 00:16 |
| 12 | 0.585929 | 0.687158 | 0.924579 | 0.923806 | 0.924579 | 0.908399 | 00:16 |
| 13 | 0.576905 | 0.680047 | 0.921029 | 0.920304 | 0.921029 | 0.908931 | 00:16 |
| 14 | 0.565662 | 0.680570 | 0.921029 | 0.919747 | 0.921029 | 0.896938 | 00:16 |
| 15 | 0.553509 | 0.652810 | 0.937001 | 0.936754 | 0.937001 | 0.923529 | 00:16 |
| 16 | 0.545040 | 0.652485 | 0.933452 | 0.933213 | 0.933452 | 0.924044 | 00:16 |
| 17 | 0.538019 | 0.641254 | 0.935226 | 0.935154 | 0.935226 | 0.926492 | 00:16 |
| 18 | 0.533519 | 0.632918 | 0.943212 | 0.942863 | 0.943212 | 0.933724 | 00:16 |
| 19 | 0.527825 | 0.631027 | 0.940550 | 0.940248 | 0.940550 | 0.925694 | 00:16 |
| 20 | 0.525236 | 0.624010 | 0.943212 | 0.943094 | 0.943212 | 0.929062 | 00:15 |
| 21 | 0.519706 | 0.623547 | 0.944987 | 0.944514 | 0.944987 | 0.933056 | 00:16 |
| 22 | 0.520856 | 0.620551 | 0.943212 | 0.942946 | 0.943212 | 0.930493 | 00:16 |
| 23 | 0.518598 | 0.620019 | 0.944987 | 0.944672 | 0.944987 | 0.930896 | 00:16 |
| 24 | 0.516994 | 0.619290 | 0.947649 | 0.947328 | 0.947649 | 0.935211 | 00:16 |
[Figure: learning curve — loss vs. steps]
Overfitting does not seem to be a significant issue: validation loss decreased throughout the entire training process, and the difference between train and validation loss is relatively low.
The table below shows the classification metrics on the full training, validation, and testing samples with all augmentations turned off.
We are using a separate test sample because the validation sample was used during hyperparameter tuning, so the selected parameters and augmentation options might be indirectly "overfitted" on it. However, considering that the dataset is relatively small, using just two hold-out samples might have to suffice (ideally we'd also use cross-validation).
Augmentation Mode: None | Data Type: Train | No augmentations applied.
| Dataset | N | Loss | Accuracy | Weighted F1 | Micro F1 | Macro F1 |
|---|---|---|---|---|---|---|
| Train | 4512 | 0.493925 | 1.000000 | 1.000000 | 1.000000 | 1.000000 |
| Validation | 1127 | 0.619294 | 0.947649 | 0.947328 | 0.947649 | 0.935211 |
| Test | 996 | 0.613501 | 0.947791 | 0.947677 | 0.947791 | 0.945465 |
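For clarity, the weighted, micro, and macro F1 scores reported above differ only in how per-class F1 values are aggregated (support-weighted mean, global TP/FP/FN counts, and unweighted mean, respectively). A toy scikit-learn example:

```python
from sklearn.metrics import f1_score

# Toy 3-class example: class supports are 3, 2, 1
y_true = [0, 0, 0, 1, 1, 2]
y_pred = [0, 0, 1, 1, 1, 2]

weighted = f1_score(y_true, y_pred, average="weighted")  # support-weighted mean
micro = f1_score(y_true, y_pred, average="micro")        # global counts (= accuracy here)
macro = f1_score(y_true, y_pred, average="macro")        # unweighted mean over classes
```

Macro F1 weights rare classes equally with common ones, which is why it is the lowest of the three on our imbalanced dataset.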
By Class Performance¶
Augmentation Mode: mult_1.25_more_trans | Data Type: Test | No augmentations applied.
| Class | Precision | Recall | F1 Score | Count |
|---|---|---|---|---|
| Agaricus | 0.921569 | 0.886792 | 0.903846 | 53 |
| Amanita | 0.937500 | 0.937500 | 0.937500 | 112 |
| Boletus | 0.958084 | 0.993789 | 0.975610 | 161 |
| Cortinarius | 0.928000 | 0.928000 | 0.928000 | 125 |
| Entoloma | 0.903846 | 0.854545 | 0.878505 | 55 |
| Hygrocybe | 1.000000 | 0.978723 | 0.989247 | 47 |
| Lactarius | 0.938053 | 0.942222 | 0.940133 | 225 |
| Russula | 0.952941 | 0.947368 | 0.950147 | 171 |
| Suillus | 0.914894 | 0.914894 | 0.914894 | 47 |
Largest Losses (i.e. worst predictions)¶
Smallest Losses (i.e. best predictions)¶
Most Confused Mushroom Types (actual, predicted, n. occurrences)¶
[('Entoloma', 'Lactarius', 5),
('Lactarius', 'Russula', 5),
('Russula', 'Lactarius', 5),
('Amanita', 'Cortinarius', 3),
('Agaricus', 'Entoloma', 2),
('Amanita', 'Agaricus', 2),
('Boletus', 'Cortinarius', 2),
('Cortinarius', 'Amanita', 2),
('Cortinarius', 'Boletus', 2),
('Cortinarius', 'Lactarius', 2),
('Cortinarius', 'Russula', 2),
('Lactarius', 'Amanita', 2),
('Lactarius', 'Entoloma', 2),
('Russula', 'Agaricus', 2),
('Russula', 'Boletus', 2),
('Suillus', 'Boletus', 2)]
Explaining Predictions using LIME¶
LIME (Local Interpretable Model-agnostic Explanations) works by fitting simple, interpretable models around individual predictions made by a complex model like ResNet. It generates a large sample of "perturbed" versions of the input image and observes how the predictions change, which allows it to highlight the image regions the main model's decision was based on.
Device: cuda
Agaricus: 0.060 Amanita: 0.049 Boletus: 0.751 Cortinarius: 0.019 Entoloma: 0.021 Hygrocybe: 0.020 Lactarius: 0.040 Russula: 0.019 Suillus: 0.019
Agaricus: 0.017 Amanita: 0.046 Boletus: 0.021 Cortinarius: 0.015 Entoloma: 0.022 Hygrocybe: 0.021 Lactarius: 0.030 Russula: 0.814 Suillus: 0.013
Agaricus: 0.010 Amanita: 0.010 Boletus: 0.916 Cortinarius: 0.010 Entoloma: 0.011 Hygrocybe: 0.015 Lactarius: 0.010 Russula: 0.010 Suillus: 0.008
Agaricus: 0.900 Amanita: 0.008 Boletus: 0.011 Cortinarius: 0.007 Entoloma: 0.034 Hygrocybe: 0.010 Lactarius: 0.005 Russula: 0.012 Suillus: 0.013
Agaricus: 0.012 Amanita: 0.882 Boletus: 0.012 Cortinarius: 0.013 Entoloma: 0.016 Hygrocybe: 0.015 Lactarius: 0.015 Russula: 0.021 Suillus: 0.016
Agaricus: 0.013 Amanita: 0.009 Boletus: 0.011 Cortinarius: 0.007 Entoloma: 0.009 Hygrocybe: 0.923 Lactarius: 0.006 Russula: 0.010 Suillus: 0.013
Inference Performance¶
The table below shows inference performance depending on the number of images in a single batch. We can see that inference on the GPU is much more scalable and allows evaluating samples in parallel.
| Device | Sample Size | Total Time (s) | Inference Time (s) | Peak Memory Usage (MB) | Per Image (ms) | |
|---|---|---|---|---|---|---|
| 0 | cpu | 1 | 3.158689 | 0.287251 | 0.585938 | 287.25 |
| 1 | cpu | 10 | 3.887971 | 1.122948 | -4.468750 | 112.29 |
| 2 | cpu | 50 | 10.299034 | 7.485476 | -1.796875 | 149.71 |
| 3 | cpu | 100 | 14.920235 | 11.995297 | 2.695312 | 119.95 |
| 4 | cpu | 500 | 16.881310 | 13.978759 | 4.289062 | 27.96 |
| 5 | cuda | 1 | 2.920804 | 0.015073 | 2.000000 | 15.07 |
| 6 | cuda | 10 | 3.015016 | 0.016202 | 2.000000 | 1.62 |
| 7 | cuda | 50 | 2.940215 | 0.016177 | 308.000000 | 0.32 |
| 8 | cuda | 100 | 3.028011 | 0.016233 | 924.000000 | 0.16 |
| 9 | cuda | 500 | 3.133774 | 0.018235 | 1572.000000 | 0.04 |
On a GPU, inference itself seems to be more or less instantaneous, with most of the time spent loading the image into memory. In this specific configuration, CPU inference is several hundred times slower, which indicates that ResNet50 isn't really suitable for CPU inference; we should choose ResNet18 or ResNet34 if that's a requirement.
Device info:
"GPU: NVIDIA GeForce RTX 3090, CPU: ('AMD EPYC 7702P 64-Core Processor', 128)"
Main Observations¶
Potential Issues¶
Model parameters were tuned using Optuna and multiple sessions:
- 40 trials for eliminating the least useful parameter values and ranges that resulted in significant performance degradation.
- 130 trials for final tuning.
The tuning process took around X.X hours on an RTX 3090, significantly improving performance over the baseline case (using `cnn_learner` + `lr_find` to find the optimal learning rate by briefly training the model on a range of learning rates), from F1 ~= 0.915 to ~= 0.965 after:
- Selecting optimal `batch_size`, `weight_decay`, `freeze_epochs`, and `pct_start` values.
- Changing the LR scheduler to `flat_cos`.
- Adding additional augmentation transformations like rotations, scaling, zooming, flipping, random erasing, warping, etc.

The tuning process was very slow and inefficient; we should be able to improve this by using a more aggressive pruning strategy and by combining it with `lr_find` instead of Bayesian optimization for selecting the optimal LR. Additionally, we only used fixed sets of augmentations, which limited the search space; we should consider tuning individual augmentation parameters as well.
Overfitting¶
- Direct overfitting does not seem to have been a significant concern; the difference between `train_loss` and `valid_loss` was at most ~0.065.
- The model was tuned with Optuna using a fixed train-validation split and a fixed seed for training. This is not ideal because the sample is very small and likely resulted in Optuna indirectly overfitting on the validation set when selecting the optimal parameters. Ideally, we'd use cross-validation, but that would have extended the tuning time significantly.
Future Improvements:¶
- Improve sample selection and filtering, and test performance on different subsamples. We've found that there is a lot of variance between images for the same class (different backgrounds, zoom levels, and composition, e.g., multiple vs. individual mushrooms); we could use this data to select more optimal training samples.